# Copyright (C) 2022 ByteDance Inc.
# All rights reserved.
# Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).

# The software is made available under Creative Commons BY-NC-SA 4.0 license
# by ByteDance Inc. You can use, redistribute, and adapt it
# for non-commercial purposes, as long as you (a) give appropriate credit
# by citing our paper, (b) indicate any changes that you've made,
# and (c) distribute any derivative works under the same license.

# THE AUTHORS DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
# IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
# DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
# OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import numpy as np
import torch
from scipy.interpolate import CubicSpline
import os
import shutil

color_map = {
    0: [239, 234, 90],
    1: [44, 105, 154],
    2: [4, 139, 168],
    3: [13, 179, 158],
    4: [131, 227, 119],
    5: [185, 231, 105],
    6: [107, 137, 198],
    7: [241, 196, 83],
    8: [242, 158, 76],
    9: [234, 114, 71],
    10: [215, 95, 155],
    11: [207, 113, 192],
    12: [0, 0, 0],
    13: [159, 89, 165],
    14: [142, 82, 172],
    15: [158, 115, 200],
    16: [116, 95, 159],
}


def get_truncation_mean(model, n_sample, sample_batch, device):
    truncation_latents = [
        torch.zeros((model.style_dim,), device=device, dtype=torch.float32)
        for _ in range(model.n_local)
    ]
    loop_cnt = n_sample // sample_batch
    assert n_sample % sample_batch == 0
    with torch.no_grad():
        for _ in range(loop_cnt):
            x = model.ipt_p_head(model.ipt)  # 1 2c
            mu = x[:, : model.style_dim]  # 1 c
            log_var = x[:, model.style_dim :]  # 1 c
            batch = sample_batch
            h = torch.zeros(
                model.rnn_layers,
                batch,
                model.rnn_hidden_dim,
                dtype=x.dtype,
                device=x.device,
            )
            c = torch.zeros(
                model.rnn_layers,
                batch,
                model.rnn_hidden_dim,
                dtype=x.dtype,
                device=x.device,
            )
            # create first input
            x = model.reparameterize(
                mu.expand(batch, -1), log_var.expand(batch, -1)
            ).unsqueeze(
                0
            )  # 1 b c

            for i in range(model.n_local):
                out, (h, c) = model.rnn(x, (h, c))
                out = model.p_head(out)  # b 2*c
                mu = out[:, :, : model.style_dim]
                log_var = out[:, :, model.style_dim :]
                # sampling
                x = model.reparameterize(mu[0], log_var[0])  # b c
                latent = model.style(x)  # b c
                truncation_latents[i] += latent.mean(dim=0)
                x = x.unsqueeze(0)
    return [t / loop_cnt for t in truncation_latents]


def safe_mkdir(outdir):
    if os.path.exists(outdir):
        y_n = (
            input(f"outdir: {outdir}\nalready exist, do you want to delete it? yes/no:")
            .strip()
            .lower()
        )
        if y_n == "yes":
            shutil.rmtree(outdir)
            os.makedirs(outdir)
        else:
            print("old dir is kept, file maybe overwrite!")
    else:
        os.makedirs(outdir)


def generate_img(
    model, styles, mean_latent=None, truncation=1.0, batch_size=16, *args, **kwargs
):
    images = []
    for head in range(0, styles.size(0), batch_size):
        images_, _ = model(
            [styles[head : head + batch_size]],
            input_is_latent=True,
            truncation=truncation,
            truncation_latent=mean_latent,
            *args,
            **kwargs,
        )
        images.append(images_)
    images = torch.cat(images, 0)
    return tensor2image(images)


def generate(
    model, styles, mean_latent=None, truncation=1.0, batch_size=16, *args, **kwargs
):
    images, segs = [], []
    for head in range(0, styles.size(0), batch_size):
        images_, segs_ = model(
            [styles[head : head + batch_size]],
            input_is_latent=True,
            truncation=truncation,
            truncation_latent=mean_latent,
            *args,
            **kwargs,
        )
        images.append(images_.detach().cpu())
        segs.append(segs_.detach().cpu())
    images, segs = torch.cat(images, 0), torch.cat(segs, 0)
    return tensor2image(images), tensor2seg(segs)


def tensor2image(tensor):
    images = tensor.cpu().clamp(-1, 1).permute(0, 2, 3, 1).numpy()
    images = images * 127.5 + 127.5
    images = images.astype(np.uint8)
    return images


def tensor2seg(sample_seg):
    seg_dim = sample_seg.size(1)
    sample_seg = torch.argmax(sample_seg, dim=1).detach().cpu().numpy()
    sample_mask = np.zeros(
        (sample_seg.shape[0], sample_seg.shape[1], sample_seg.shape[2], 3),
        dtype=np.uint8,
    )
    for key in range(seg_dim):
        sample_mask[sample_seg == key] = color_map[key]
    return sample_mask


def cubic_spline_interpolate(styles, step):
    device = styles.device
    styles = styles.detach().cpu().numpy()
    N, K, D = styles.shape
    x = np.linspace(0.0, 1.0, N)
    y = styles.reshape(N, K * D)
    spl = CubicSpline(x, y)
    x_out = np.linspace(0.0, 1.0, step)
    results = spl(x_out)  # Step x KD
    results = results.reshape(step, K, D)
    return torch.tensor(results, device=device).float()
